Inspired by the following tweets:
Basic idea:
# Add something to gradient
f(x) + g(x) - tf.stop_gradient(g(x))
# Reverse gradient
tf.stop_gradient(f(x)*2) - f(x)
In [1]:
import torch
import tensorflow as tf
from torch.autograd import Variable
import numpy as np
In [2]:
def f(X):
    """Base function whose gradient we will modify: f(x) = x^2 (elementwise)."""
    return X * X

def g(X):
    """Gradient-modifier function: g(x) = x^3, so g'(x) = 3x^2 is the term we inject."""
    return X ** 3
In [3]:
X = np.random.randn(10)  # 10 samples from a standard normal; shared input for the TF and PyTorch demos
X
Out[3]:
In [4]:
# TF1-style workflow: an interactive session lets us run ops on demand
sess = tf.InteractiveSession()
In [5]:
tf_X = tf.Variable(X)  # TF variable initialized from the numpy array
init_op = tf.global_variables_initializer()
In [6]:
sess.run(init_op)  # materialize variable values before use
sess.run(tf_X)  # should echo X
Out[6]:
In [7]:
forward_op = f(tf_X)  # forward pass: x^2
In [8]:
sess.run(forward_op)
Out[8]:
In [9]:
# Symbolic gradient of the forward op w.r.t. the input: d/dx x^2 = 2x
gradient_op = tf.gradients(forward_op, tf_X)
In [10]:
sess.run(gradient_op)
Out[10]:
In [11]:
X*2 # This should match the gradient above
Out[11]:
Keep the forward pass the same. The trick is to add $g(x)$, such that $g'(x)$ is the gradient modifier, during the forward pass and subtract it as well — but stop gradients from flowing through the subtracted term.
$f(x) + g(x) - g(x)$ has the same forward value as $f(x)$ but would lead to gradients $f'(x) + g'(x) - g'(x)$. Since gradients are stopped from flowing through the subtracted $-g(x)$ term, its $-g'(x)$ contribution vanishes, and we get new gradients $f'(x) + g'(x)$.
In [12]:
gradient_modifier_op = g(tf_X)  # g(x) = x^3; its derivative 3x^2 is the term to add to the gradient
In [13]:
sess.run(gradient_modifier_op)
Out[13]:
In [14]:
# Forward value is unchanged (the +g and -g cancel), but stop_gradient blocks the
# backward path through the subtracted copy, so the gradient becomes f' + g'.
modified_forward_op = (f(tf_X) + g(tf_X) - tf.stop_gradient(g(tf_X)))
modified_backward_op = tf.gradients(modified_forward_op, tf_X)
In [15]:
sess.run(modified_forward_op)  # equals f(X) = X^2
Out[15]:
In [16]:
sess.run(modified_backward_op)  # equals 2X + 3X^2
Out[16]:
In [17]:
2*X + 3*(X**2) # This should match the gradients above
Out[17]:
In [18]:
# Forward value: 2f - f = f (unchanged). Backward: the stopped 2f term contributes
# nothing, so the gradient is -f'(x) — i.e. the gradient is reversed.
gradient_reversal_op = (tf.stop_gradient(2*f(tf_X)) - f(tf_X))
gradient_reversal_grad_op = tf.gradients(gradient_reversal_op, tf_X)
In [19]:
sess.run(gradient_reversal_op)  # still X^2
Out[19]:
In [20]:
sess.run(gradient_reversal_grad_op)  # -2X
Out[20]:
In [21]:
sess.run((gradient_op[0] + gradient_reversal_grad_op[0])) # This should be zero. Signifying grad is reversed.
Out[21]:
In [22]:
def zero_grad(X):
    """Reset the accumulated gradient on ``X`` in place.

    A no-op when ``X`` has no gradient yet (``X.grad is None``), i.e. before
    the first backward pass. Needed because PyTorch accumulates gradients
    across backward calls.
    """
    grad = X.grad
    if grad is None:
        return
    grad.data.zero_()
In [23]:
# NOTE(review): Variable is the pre-0.4 PyTorch autograd wrapper; the modern
# equivalent is torch.tensor(X, dtype=torch.float32, requires_grad=True).
# FloatTensor casts the float64 numpy array to float32.
torch_X = Variable(torch.FloatTensor(X), requires_grad=True)
In [24]:
torch_X.data.numpy()
Out[24]:
In [25]:
f(torch_X).data.numpy()  # forward pass: X^2
Out[25]:
In [26]:
g(torch_X).data.numpy()  # X^3
Out[26]:
In [27]:
zero_grad(torch_X)
f_X = f(torch_X)
f_X.backward(torch.ones(f_X.size()))  # seed backward with ones since f_X is non-scalar
torch_X.grad.data.numpy()
Out[27]:
In [28]:
2*X
Out[28]:
In [29]:
def modified_gradients_forward(x):
    """Forward value equals f(x) (the +g/-g cancel), but the backward pass
    gains g'(x), because .detach() blocks the gradient through the
    subtracted copy — the PyTorch analogue of tf.stop_gradient."""
    modifier = g(x)
    return f(x) + modifier - modifier.detach()
In [30]:
zero_grad(torch_X)  # gradients accumulate in PyTorch, so clear before each backward
modified_grad = modified_gradients_forward(torch_X)
modified_grad.backward(torch.ones(modified_grad.size()))
torch_X.grad.data.numpy()  # expect 2X + 3X^2
Out[30]:
In [31]:
2*X + 3*(X*X) # It should be same as above
Out[31]:
In [32]:
def gradient_reversal(x):
    """Same forward value as f(x) (2f - f = f), but since the 2f term is
    detached, the backward pass sees only -f'(x): the gradient is reversed."""
    detached_double = (2*f(x)).detach()
    return detached_double - f(x)
In [33]:
zero_grad(torch_X)
grad_reverse = gradient_reversal(torch_X)
grad_reverse.backward(torch.ones(grad_reverse.size()))
torch_X.grad.data.numpy()  # expect -2X: the reversed gradient of f
Out[33]:
In [34]:
-2*X # It should be same as above
Out[34]:
In [35]:
# Gradient reversal
# Alternative technique: a backward hook on the output tensor negates the
# incoming gradient, achieving reversal without the detach arithmetic trick.
zero_grad(torch_X)
f_X = f(torch_X)
f_X.register_hook(lambda grad: -grad)
f_X.backward(torch.ones(f_X.size()))
torch_X.grad.data.numpy()  # expect -2X
Out[35]:
In [36]:
-2*X
Out[36]:
In [37]:
# Modified grad example
# Hook on the *leaf* tensor this time: it adds 3x^2 to whatever gradient
# arrives, reproducing the f' + g' result. The handle is removed afterwards
# so later backward passes are unaffected.
zero_grad(torch_X)
h = torch_X.register_hook(lambda grad: grad + 3*(torch_X*torch_X))
f_X = f(torch_X)
f_X.backward(torch.ones(f_X.size()))
h.remove()
torch_X.grad.data.numpy()
Out[37]:
In [38]:
2*X + 3*(X*X) # It should be same as above
Out[38]:
In [ ]: